{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## relation extraction 实践\n",
    "> Tutorial作者：余海阳（yuhaiyang@zju.edu.cn)\n",
    "\n",
    "在这个演示中，我们使用 `pcnn` 模型实现中文关系抽取。\n",
    "希望在这个demo中帮助大家了解知识图谱构建过程中，三元组抽取构建的原理和常用方法。\n",
    "\n",
    "本demo使用 `python3` 运⾏。\n",
    "\n",
    "### 数据集\n",
    "在这个示例中，我们采样了一些中文文本，抽取其中的三元组。\n",
    "\n",
    "sentence|relation|head|tail\n",
    ":---:|:---:|:---:|:---:\n",
    "孔正锡在2005年以一部温馨的爱情电影《长腿叔叔》敲开电影界大门。|导演|长腿叔叔|孔正锡\n",
    "《伤心的树》是吴宗宪的音乐作品，收录在《你比从前快乐》专辑中。|所属专辑|伤心的树|你比从前快乐\n",
    "2000年8月，「天坛大佛」荣获「香港十大杰出工程项目」第四名。|所在城市|天坛大佛|香港\n",
    "\n",
    "\n",
    "- train.csv: 包含6个训练三元组，文件的每一⾏表示一个三元组, 按句子、关系、头实体、尾实体排序，并用`,`分隔。\n",
    "- valid.csv: 包含3个验证三元组，文件的每一⾏表示一个三元组, 按句子、关系、头实体、尾实体排序，并用`,`分隔。\n",
    "- test.csv:  包含3个测试三元组，文件的每一⾏表示一个三元组, 按句子、关系、头实体、尾实体排序，并用`,`分隔。\n",
    "- relation.csv: 包含4种关系三元组，文件的每一⾏表示一个三元组种类, 按头实体种类、尾实体种类、关系、序号排序，并用`,`分隔。"
   ]
  },
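  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the file layout described above concrete, the following cell is a minimal sketch that parses two of the sample rows with `csv.DictReader`. The in-memory `sample_csv` string, its header row, and the four column names are assumptions taken from the table above, not the actual contents of `train.csv`. The preprocessing cells later in this notebook also read further columns (`tokens`, `head_offset`, `tail_offset`, `lens`) that this sketch does not cover."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "import io\n",
    "\n",
    "# Hypothetical in-memory stand-in for a file such as train.csv.\n",
    "# The header row and column names are assumptions taken from the table above.\n",
    "sample_csv = (\n",
    "    'sentence,relation,head,tail\\n'\n",
    "    '《伤心的树》是吴宗宪的音乐作品，收录在《你比从前快乐》专辑中。,所属专辑,伤心的树,你比从前快乐\\n'\n",
    "    '2000年8月，「天坛大佛」荣获「香港十大杰出工程项目」第四名。,所在城市,天坛大佛,香港\\n'\n",
    ")\n",
    "\n",
    "# DictReader uses the first line as field names and yields one dict per row\n",
    "rows = list(csv.DictReader(io.StringIO(sample_csv)))\n",
    "for row in rows:\n",
    "    # Each row is one (head, relation, tail) triple plus the sentence it came from\n",
    "    print((row['head'], row['relation'], row['tail']))"
   ]
  },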
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### PCNN 原理回顾\n",
    "\n",
    "![PCNN](img/PCNN.jpg)\n",
    "\n",
    "句子信息主要包括word embedding和position embedding，在经过卷积层后，按照head tail的位置，分为三段最大池化，然后经过全连接层，即可得到句子的关系信息。"
   ]
  },
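  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The piecewise pooling step described above can be sketched on a toy tensor. This is a minimal illustration, not the implementation used later in this notebook: it loops over the three segments and hides the other positions with `-inf`, whereas the `PCNN` class below relies on a fixed mask embedding. The tensor values and the `segments` ids are hypothetical. With two channels and three segments, the pooled vector has 3 * 2 = 6 entries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Toy convolution output: batch of 1, 6 positions, 2 channels (hypothetical values)\n",
    "feats = torch.tensor([[[1., 0.], [2., 5.], [0., 1.], [4., 2.], [3., 3.], [1., 7.]]])\n",
    "# Segment id per position (1, 2 or 3), chosen by hand for illustration\n",
    "segments = torch.tensor([[1, 1, 2, 2, 3, 3]])\n",
    "\n",
    "pooled = []\n",
    "for seg_id in (1, 2, 3):\n",
    "    keep = (segments == seg_id).unsqueeze(-1)            # [B, L, 1]\n",
    "    seg_feats = feats.masked_fill(~keep, float('-inf'))  # hide the other segments\n",
    "    pooled.append(seg_feats.max(dim=1).values)           # [B, C]\n",
    "\n",
    "# Concatenate the three per-segment maxima along the channel axis\n",
    "piecewise = torch.cat(pooled, dim=-1)  # [B, 3 * C]\n",
    "print(piecewise)"
   ]
  },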
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 使用pytorch运行神经网络，运行前确认是否安装\n",
    "!pip install torch\n",
    "!pip install matplotlib\n",
    "!pip install transformers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 导入所使用模块\n",
    "import os\n",
    "import csv\n",
    "import math\n",
    "import pickle\n",
    "import logging\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from torch import optim\n",
    "from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence\n",
    "from torch.utils.data import Dataset,DataLoader\n",
    "from sklearn.metrics import precision_recall_fscore_support\n",
    "from typing import List, Tuple, Dict, Any, Sequence, Optional, Union\n",
    "from transformers import BertTokenizer, BertModel\n",
    "\n",
    "logger = logging.getLogger(__name__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 模型调参的配置文件\n",
    "# use_pcnn 参数控制是否是有piece_wise 的方式池化\n",
    "class Config(object):\n",
    "    model_name = 'cnn'  # ['cnn', 'gcn', 'lm']\n",
    "    use_pcnn = True\n",
    "    min_freq = 1\n",
    "    pos_limit = 20\n",
    "    out_path = 'data/out'   \n",
    "    batch_size = 2  \n",
    "    word_dim = 10\n",
    "    pos_dim = 5\n",
    "    dim_strategy = 'sum'  # ['sum', 'cat']\n",
    "    out_channels = 20\n",
    "    intermediate = 10\n",
    "    kernel_sizes = [3, 5, 7]\n",
    "    activation = 'gelu'\n",
    "    pooling_strategy = 'max'\n",
    "    dropout = 0.3\n",
    "    epoch = 10\n",
    "    num_relations = 4\n",
    "    learning_rate = 3e-4\n",
    "    lr_factor = 0.7 # 学习率的衰减率\n",
    "    lr_patience = 3 # 学习率衰减的等待epoch\n",
    "    weight_decay = 1e-3 # L2正则\n",
    "    early_stopping_patience = 6\n",
    "    train_log = True\n",
    "    log_interval = 1\n",
    "    show_plot = True\n",
    "    only_comparison_plot = False\n",
    "    plot_utils = 'matplot'\n",
    "    lm_file = 'bert-base-chinese'\n",
    "    lm_num_hidden_layers = 2\n",
    "    rnn_layers = 2\n",
    "    \n",
    "cfg = Config()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# word token 构建 one-hot 词典，后续输入到embedding层得到对应word信息矩阵\n",
    "# 一般默认0为pad，1为unknown\n",
    "class Vocab(object):\n",
    "    def __init__(self, name: str = 'basic', init_tokens = [\"[PAD]\", \"[UNK]\"]):\n",
    "        self.name = name\n",
    "        self.init_tokens = init_tokens\n",
    "        self.trimed = False\n",
    "        self.word2idx = {}\n",
    "        self.word2count = {}\n",
    "        self.idx2word = {}\n",
    "        self.count = 0\n",
    "        self._add_init_tokens()\n",
    "\n",
    "    def _add_init_tokens(self):\n",
    "        for token in self.init_tokens:\n",
    "            self._add_word(token)\n",
    "\n",
    "    def _add_word(self, word: str):\n",
    "        if word not in self.word2idx:\n",
    "            self.word2idx[word] = self.count\n",
    "            self.word2count[word] = 1\n",
    "            self.idx2word[self.count] = word\n",
    "            self.count += 1\n",
    "        else:\n",
    "            self.word2count[word] += 1\n",
    "\n",
    "    def add_words(self, words: Sequence):\n",
    "        for word in words:\n",
    "            self._add_word(word)\n",
    "\n",
    "    def trim(self, min_freq=2, verbose: Optional[bool] = True):\n",
    "        '''当 word 词频低于 min_freq 时，从词库中删除\n",
    "\n",
    "        Args:\n",
    "            param min_freq: 最低词频\n",
    "        '''\n",
    "        assert min_freq == int(min_freq), f'min_freq must be integer, can\\'t be {min_freq}'\n",
    "        min_freq = int(min_freq)\n",
    "        if min_freq < 2:\n",
    "            return\n",
    "        if self.trimed:\n",
    "            return\n",
    "        self.trimed = True\n",
    "\n",
    "        keep_words = []\n",
    "        new_words = []\n",
    "\n",
    "        for k, v in self.word2count.items():\n",
    "            if v >= min_freq:\n",
    "                keep_words.append(k)\n",
    "                new_words.extend([k] * v)\n",
    "        if verbose:\n",
    "            before_len = len(keep_words)\n",
    "            after_len = len(self.word2idx) - len(self.init_tokens)\n",
    "            logger.info('vocab after be trimmed, keep words [{} / {}] = {:.2f}%'.format(before_len, after_len, before_len / after_len * 100))\n",
    "\n",
    "        # Reinitialize dictionaries\n",
    "        self.word2idx = {}\n",
    "        self.word2count = {}\n",
    "        self.idx2word = {}\n",
    "        self.count = 0\n",
    "        self._add_init_tokens()\n",
    "        self.add_words(new_words)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 预处理过程所需要使用的函数\n",
    "Path = str\n",
    "\n",
    "def load_csv(fp: Path, is_tsv: bool = False, verbose: bool = True) -> List:\n",
    "    if verbose:\n",
    "        logger.info(f'load csv from {fp}')\n",
    "\n",
    "    dialect = 'excel-tab' if is_tsv else 'excel'\n",
    "    with open(fp, encoding='utf-8') as f:\n",
    "        reader = csv.DictReader(f, dialect=dialect)\n",
    "        return list(reader)\n",
    "\n",
    "    \n",
    "def load_pkl(fp: Path, verbose: bool = True) -> Any:\n",
    "    if verbose:\n",
    "        logger.info(f'load data from {fp}')\n",
    "\n",
    "    with open(fp, 'rb') as f:\n",
    "        data = pickle.load(f)\n",
    "        return data\n",
    "\n",
    "\n",
    "def save_pkl(data: Any, fp: Path, verbose: bool = True) -> None:\n",
    "    if verbose:\n",
    "        logger.info(f'save data in {fp}')\n",
    "\n",
    "    with open(fp, 'wb') as f:\n",
    "        pickle.dump(data, f)\n",
    "    \n",
    "    \n",
    "def _handle_relation_data(relation_data: List[Dict]) -> Dict:\n",
    "    rels = dict()\n",
    "    for d in relation_data:\n",
    "        rels[d['relation']] = {\n",
    "            'index': int(d['index']),\n",
    "            'head_type': d['head_type'],\n",
    "            'tail_type': d['tail_type'],\n",
    "        }\n",
    "    return rels\n",
    "\n",
    "\n",
    "def _add_relation_data(rels: Dict,data: List) -> None:\n",
    "    for d in data:\n",
    "        d['rel2idx'] = rels[d['relation']]['index']\n",
    "        d['head_type'] = rels[d['relation']]['head_type']\n",
    "        d['tail_type'] = rels[d['relation']]['tail_type']\n",
    "\n",
    "\n",
    "def _convert_tokens_into_index(data: List[Dict], vocab):\n",
    "    unk_str = '[UNK]'\n",
    "    unk_idx = vocab.word2idx[unk_str]\n",
    "\n",
    "    for d in data:\n",
    "        d['token2idx'] = [vocab.word2idx.get(i, unk_idx) for i in d['tokens']]\n",
    "\n",
    "\n",
    "def _add_pos_seq(train_data: List[Dict], cfg):\n",
    "    for d in train_data:\n",
    "        d['head_offset'], d['tail_offset'], d['lens'] = int(d['head_offset']), int(d['tail_offset']), int(d['lens'])\n",
    "        entities_idx = [d['head_offset'], d['tail_offset']] if d['head_offset'] < d['tail_offset'] else [d['tail_offset'], d['head_offset']]\n",
    "\n",
    "        d['head_pos'] = list(map(lambda i: i - d['head_offset'], list(range(d['lens']))))\n",
    "        d['head_pos'] = _handle_pos_limit(d['head_pos'], int(cfg.pos_limit))\n",
    "\n",
    "        d['tail_pos'] = list(map(lambda i: i - d['tail_offset'], list(range(d['lens']))))\n",
    "        d['tail_pos'] = _handle_pos_limit(d['tail_pos'], int(cfg.pos_limit))\n",
    "\n",
    "        if cfg.use_pcnn:\n",
    "            d['entities_pos'] = [1] * (entities_idx[0] + 1) + [2] * (entities_idx[1] - entities_idx[0] - 1) +\\\n",
    "                                [3] * (d['lens'] - entities_idx[1])\n",
    "\n",
    "            \n",
    "def _handle_pos_limit(pos: List[int], limit: int) -> List[int]:\n",
    "    for i, p in enumerate(pos):\n",
    "        if p > limit:\n",
    "            pos[i] = limit\n",
    "        if p < -limit:\n",
    "            pos[i] = -limit\n",
    "    return [p + limit + 1 for p in pos]\n",
    "\n",
    "\n",
    "def seq_len_to_mask(seq_len: Union[List, np.ndarray, torch.Tensor], max_len=None, mask_pos_to_true=True):\n",
    "    \"\"\"\n",
    "    将一个表示sequence length的一维数组转换为二维的mask，默认pad的位置为1。\n",
    "    转变 1-d seq_len到2-d mask.\n",
    "\n",
    "    :param list, np.ndarray, torch.LongTensor seq_len: shape将是(B,)\n",
    "    :param int max_len: 将长度pad到这个长度。默认(None)使用的是seq_len中最长的长度。但在nn.DataParallel的场景下可能不同卡的seq_len会有\n",
    "        区别，所以需要传入一个max_len使得mask的长度是pad到该长度。\n",
    "    :return: np.ndarray, torch.Tensor 。shape将是(B, max_length)， 元素类似为bool或torch.uint8\n",
    "    \"\"\"\n",
    "    if isinstance(seq_len, list):\n",
    "        seq_len = np.array(seq_len)\n",
    "\n",
    "    if isinstance(seq_len, np.ndarray):\n",
    "        seq_len = torch.from_numpy(seq_len)\n",
    "\n",
    "    if isinstance(seq_len, torch.Tensor):\n",
    "        assert seq_len.dim() == 1, logger.error(f\"seq_len can only have one dimension, got {seq_len.dim()} != 1.\")\n",
    "        batch_size = seq_len.size(0)\n",
    "        max_len = int(max_len) if max_len else seq_len.max().long()\n",
    "        broad_cast_seq_len = torch.arange(max_len).expand(batch_size, -1).to(seq_len.device)\n",
    "        if mask_pos_to_true:\n",
    "            mask = broad_cast_seq_len.ge(seq_len.unsqueeze(1))\n",
    "        else:\n",
    "            mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1))\n",
    "    else:\n",
    "        raise logger.error(\"Only support 1-d list or 1-d numpy.ndarray or 1-d torch.Tensor.\")\n",
    "\n",
    "    return mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 预处理过程\n",
    "logger.info('load raw files...')\n",
    "train_fp = os.path.join('data/train.csv')\n",
    "valid_fp = os.path.join('data/valid.csv')\n",
    "test_fp = os.path.join('data/test.csv')\n",
    "relation_fp = os.path.join('data/relation.csv')\n",
    "\n",
    "train_data = load_csv(train_fp)\n",
    "valid_data = load_csv(valid_fp)\n",
    "test_data = load_csv(test_fp)\n",
    "relation_data = load_csv(relation_fp)\n",
    "\n",
    "for d in train_data:\n",
    "    d['tokens'] = eval(d['tokens'])\n",
    "for d in valid_data:\n",
    "    d['tokens'] = eval(d['tokens'])\n",
    "for d in test_data:\n",
    "    d['tokens'] = eval(d['tokens'])\n",
    "    \n",
    "logger.info('convert relation into index...')\n",
    "rels = _handle_relation_data(relation_data)\n",
    "_add_relation_data(rels, train_data)\n",
    "_add_relation_data(rels, valid_data)\n",
    "_add_relation_data(rels, test_data)\n",
    "\n",
    "logger.info('verify whether use pretrained language models...')\n",
    "logger.info('build vocabulary...')\n",
    "vocab = Vocab('word')\n",
    "train_tokens = [d['tokens'] for d in train_data]\n",
    "valid_tokens = [d['tokens'] for d in valid_data]\n",
    "test_tokens = [d['tokens'] for d in test_data]\n",
    "sent_tokens = [*train_tokens, *valid_tokens, *test_tokens]\n",
    "for sent in sent_tokens:\n",
    "    vocab.add_words(sent)\n",
    "vocab.trim(min_freq=cfg.min_freq)\n",
    "\n",
    "logger.info('convert tokens into index...')\n",
    "_convert_tokens_into_index(train_data, vocab)\n",
    "_convert_tokens_into_index(valid_data, vocab)\n",
    "_convert_tokens_into_index(test_data, vocab)\n",
    "\n",
    "logger.info('build position sequence...')\n",
    "_add_pos_seq(train_data, cfg)\n",
    "_add_pos_seq(valid_data, cfg)\n",
    "_add_pos_seq(test_data, cfg)\n",
    "\n",
    "logger.info('save data for backup...')\n",
    "os.makedirs(cfg.out_path, exist_ok=True)\n",
    "train_save_fp = os.path.join(cfg.out_path, 'train.pkl')\n",
    "valid_save_fp = os.path.join(cfg.out_path, 'valid.pkl')\n",
    "test_save_fp = os.path.join(cfg.out_path, 'test.pkl')\n",
    "save_pkl(train_data, train_save_fp)\n",
    "save_pkl(valid_data, valid_save_fp)\n",
    "save_pkl(test_data, test_save_fp)\n",
    "\n",
    "vocab_save_fp = os.path.join(cfg.out_path, 'vocab.pkl')\n",
    "vocab_txt = os.path.join(cfg.out_path, 'vocab.txt')\n",
    "save_pkl(vocab, vocab_save_fp)\n",
    "logger.info('save vocab in txt file, for watching...')\n",
    "with open(vocab_txt, 'w', encoding='utf-8') as f:\n",
    "    f.write(os.linesep.join(vocab.word2idx.keys()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pytorch 构建自定义 Dataset\n",
    "def collate_fn(cfg):\n",
    "    def collate_fn_intra(batch):\n",
    "        batch.sort(key=lambda data: int(data['lens']), reverse=True)\n",
    "        max_len = int(batch[0]['lens'])\n",
    "        \n",
    "        def _padding(x, max_len):\n",
    "            return x + [0] * (max_len - len(x))\n",
    "        \n",
    "        def _pad_adj(adj, max_len):\n",
    "            adj = np.array(adj)\n",
    "            pad_len = max_len - adj.shape[0]\n",
    "            for i in range(pad_len):\n",
    "                adj = np.insert(adj, adj.shape[-1], 0, axis=1)\n",
    "            for i in range(pad_len):\n",
    "                adj = np.insert(adj, adj.shape[0], 0, axis=0)\n",
    "            return adj\n",
    "        \n",
    "        x, y = dict(), []\n",
    "        word, word_len = [], []\n",
    "        head_pos, tail_pos = [], []\n",
    "        pcnn_mask = []\n",
    "        adj_matrix = []\n",
    "        for data in batch:\n",
    "            word.append(_padding(data['token2idx'], max_len))\n",
    "            word_len.append(int(data['lens']))\n",
    "            y.append(int(data['rel2idx']))\n",
    "            \n",
    "            if cfg.model_name != 'lm':\n",
    "                head_pos.append(_padding(data['head_pos'], max_len))\n",
    "                tail_pos.append(_padding(data['tail_pos'], max_len))\n",
    "                if cfg.model_name == 'gcn':\n",
    "                    head = eval(data['dependency'])\n",
    "                    adj = head_to_adj(head, directed=True, self_loop=True)\n",
    "                    adj_matrix.append(_pad_adj(adj, max_len))\n",
    "\n",
    "                if cfg.use_pcnn:\n",
    "                    pcnn_mask.append(_padding(data['entities_pos'], max_len))\n",
    "\n",
    "        x['word'] = torch.tensor(word)\n",
    "        x['lens'] = torch.tensor(word_len)\n",
    "        y = torch.tensor(y)\n",
    "        \n",
    "        if cfg.model_name != 'lm':\n",
    "            x['head_pos'] = torch.tensor(head_pos)\n",
    "            x['tail_pos'] = torch.tensor(tail_pos)\n",
    "            if cfg.model_name == 'gcn':\n",
    "                x['adj'] = torch.tensor(adj_matrix)\n",
    "            if cfg.model_name == 'cnn' and cfg.use_pcnn:\n",
    "                x['pcnn_mask'] = torch.tensor(pcnn_mask)\n",
    "\n",
    "        return x, y\n",
    "    \n",
    "    return collate_fn_intra\n",
    "\n",
    "\n",
    "class CustomDataset(Dataset):\n",
    "    \"\"\"默认使用 List 存储数据\"\"\"\n",
    "    def __init__(self, fp):\n",
    "        self.file = load_pkl(fp)\n",
    "\n",
    "    def __getitem__(self, item):\n",
    "        sample = self.file[item]\n",
    "        return sample\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# embedding层\n",
    "class Embedding(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        \"\"\"\n",
    "        word embedding: 一般 0 为 padding\n",
    "        pos embedding:  一般 0 为 padding\n",
    "        dim_strategy: [cat, sum]  多个 embedding 是拼接还是相加\n",
    "        \"\"\"\n",
    "        super(Embedding, self).__init__()\n",
    "\n",
    "        # self.xxx = config.xxx\n",
    "        self.vocab_size = config.vocab_size\n",
    "        self.word_dim = config.word_dim\n",
    "        self.pos_size = config.pos_limit * 2 + 2\n",
    "        self.pos_dim = config.pos_dim if config.dim_strategy == 'cat' else config.word_dim\n",
    "        self.dim_strategy = config.dim_strategy\n",
    "\n",
    "        self.wordEmbed = nn.Embedding(self.vocab_size,self.word_dim,padding_idx=0)\n",
    "        self.headPosEmbed = nn.Embedding(self.pos_size,self.pos_dim,padding_idx=0)\n",
    "        self.tailPosEmbed = nn.Embedding(self.pos_size,self.pos_dim,padding_idx=0)\n",
    "\n",
    "\n",
    "    def forward(self, *x):\n",
    "        word, head, tail = x\n",
    "        word_embedding = self.wordEmbed(word)\n",
    "        head_embedding = self.headPosEmbed(head)\n",
    "        tail_embedding = self.tailPosEmbed(tail)\n",
    "\n",
    "        if self.dim_strategy == 'cat':\n",
    "            return torch.cat((word_embedding,head_embedding, tail_embedding), -1)\n",
    "        elif self.dim_strategy == 'sum':\n",
    "            # 此时 pos_dim == word_dim\n",
    "            return word_embedding + head_embedding + tail_embedding\n",
    "        else:\n",
    "            raise Exception('dim_strategy must choose from [sum, cat]')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# gelu激活函数，transformer指定，效果比relu好\n",
    "class GELU(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(GELU, self).__init__()\n",
    "\n",
    "    def forward(self, x):\n",
    "        return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# cnn 模型\n",
    "class CNN(nn.Module):\n",
    "    \"\"\"\n",
    "    nlp 里为了保证输出的句长 = 输入的句长，一般使用奇数 kernel_size，如 [3, 5, 7, 9]\n",
    "    此时，padding = k // 2\n",
    "    stride 一般为 1\n",
    "    \"\"\"\n",
    "    def __init__(self, config):\n",
    "        \"\"\"\n",
    "        in_channels      : 一般就是 word embedding 的维度，或者 hidden size 的维度\n",
    "        out_channels     : int\n",
    "        kernel_sizes     : list 为了保证输出长度=输入长度，必须为奇数: 3, 5, 7...\n",
    "        activation       : [relu, lrelu, prelu, selu, celu, gelu, sigmoid, tanh]\n",
    "        pooling_strategy : [max, avg, cls]\n",
    "        dropout:         : float\n",
    "        \"\"\"\n",
    "        super(CNN, self).__init__()\n",
    "\n",
    "        # self.xxx = config.xxx\n",
    "        # self.in_channels = config.in_channels\n",
    "        if config.dim_strategy == 'cat':\n",
    "            self.in_channels = config.word_dim + 2 * config.pos_dim\n",
    "        else:\n",
    "            self.in_channels = config.word_dim\n",
    "\n",
    "        self.out_channels = config.out_channels\n",
    "        self.kernel_sizes = config.kernel_sizes\n",
    "        self.activation = config.activation\n",
    "        self.pooling_strategy = config.pooling_strategy\n",
    "        self.dropout = config.dropout\n",
    "        for kernel_size in self.kernel_sizes:\n",
    "            assert kernel_size % 2 == 1, \"kernel size has to be odd numbers.\"\n",
    "\n",
    "        # convolution\n",
    "        self.convs = nn.ModuleList([\n",
    "            nn.Conv1d(in_channels=self.in_channels,\n",
    "                      out_channels=self.out_channels,\n",
    "                      kernel_size=k,\n",
    "                      stride=1,\n",
    "                      padding=k // 2,\n",
    "                      dilation=1,\n",
    "                      groups=1,\n",
    "                      bias=False) for k in self.kernel_sizes\n",
    "        ])\n",
    "\n",
    "        # activation function\n",
    "        assert self.activation in ['relu', 'lrelu', 'prelu', 'selu', 'celu', 'gelu', 'sigmoid', 'tanh'], \\\n",
    "            'activation function must choose from [relu, lrelu, prelu, selu, celu, gelu, sigmoid, tanh]'\n",
    "        self.activations = nn.ModuleDict([\n",
    "            ['relu', nn.ReLU()],\n",
    "            ['lrelu', nn.LeakyReLU()],\n",
    "            ['prelu', nn.PReLU()],\n",
    "            ['selu', nn.SELU()],\n",
    "            ['celu', nn.CELU()],\n",
    "            ['gelu', GELU()],\n",
    "            ['sigmoid', nn.Sigmoid()],\n",
    "            ['tanh', nn.Tanh()],\n",
    "        ])\n",
    "\n",
    "        # pooling\n",
    "        assert self.pooling_strategy in ['max', 'avg', 'cls'], 'pooling strategy must choose from [max, avg, cls]'\n",
    "\n",
    "        self.dropout = nn.Dropout(self.dropout)\n",
    "\n",
    "    def forward(self, x, mask=None):\n",
    "        \"\"\"\n",
    "            :param x: torch.Tensor [batch_size, seq_max_length, input_size], [B, L, H] 一般是经过embedding后的值\n",
    "            :param mask: [batch_size, max_len], 句长部分为0，padding部分为1。不影响卷积运算，max-pool一定不会pool到pad为0的位置\n",
    "            :return:\n",
    "            \"\"\"\n",
    "        # [B, L, H] -> [B, H, L] （注释：将 H 维度当作输入 channel 维度)\n",
    "        x = torch.transpose(x, 1, 2)\n",
    "\n",
    "        # convolution + activation  [[B, H, L], ... ]\n",
    "        act_fn = self.activations[self.activation]\n",
    "\n",
    "        x = [act_fn(conv(x)) for conv in self.convs]\n",
    "        x = torch.cat(x, dim=1)\n",
    "\n",
    "        # mask\n",
    "        if mask is not None:\n",
    "            # [B, L] -> [B, 1, L]\n",
    "            mask = mask.unsqueeze(1)\n",
    "            x = x.masked_fill_(mask, 1e-12)\n",
    "\n",
    "        # pooling\n",
    "        # [[B, H, L], ... ] -> [[B, H], ... ]\n",
    "        if self.pooling_strategy == 'max':\n",
    "            xp = F.max_pool1d(x, kernel_size=x.size(2)).squeeze(2)\n",
    "            # 等价于 xp = torch.max(x, dim=2)[0]\n",
    "\n",
    "        elif self.pooling_strategy == 'avg':\n",
    "            x_len = mask.squeeze().eq(0).sum(-1).unsqueeze(-1).to(torch.float).to(device=mask.device)\n",
    "            xp = torch.sum(x, dim=-1) / x_len\n",
    "\n",
    "        else:\n",
    "            # self.pooling_strategy == 'cls'\n",
    "            xp = x[:, :, 0]\n",
    "\n",
    "        x = x.transpose(1, 2)\n",
    "        x = self.dropout(x)\n",
    "        xp = self.dropout(xp)\n",
    "\n",
    "        return x, xp  # [B, L, Hs], [B, Hs]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pcnn 模型\n",
    "class PCNN(nn.Module):\n",
    "    def __init__(self, cfg):\n",
    "        super(PCNN, self).__init__()\n",
    "\n",
    "        self.use_pcnn = cfg.use_pcnn\n",
    "\n",
    "        self.embedding = Embedding(cfg)\n",
    "        self.cnn = CNN(cfg)\n",
    "        self.fc1 = nn.Linear(len(cfg.kernel_sizes) * cfg.out_channels, cfg.intermediate)\n",
    "        self.fc2 = nn.Linear(cfg.intermediate, cfg.num_relations)\n",
    "        self.dropout = nn.Dropout(cfg.dropout)\n",
    "\n",
    "        if self.use_pcnn:\n",
    "            self.fc_pcnn = nn.Linear(3 * len(cfg.kernel_sizes) * cfg.out_channels,\n",
    "                                     len(cfg.kernel_sizes) * cfg.out_channels)\n",
    "            self.pcnn_mask_embedding = nn.Embedding(4, 3)\n",
    "            masks = torch.tensor([[0, 0, 0], [100, 0, 0], [0, 100, 0], [0, 0, 100]])\n",
    "            self.pcnn_mask_embedding.weight.data.copy_(masks)\n",
    "            self.pcnn_mask_embedding.weight.requires_grad = False\n",
    "\n",
    "\n",
    "    def forward(self, x):\n",
    "        word, lens, head_pos, tail_pos = x['word'], x['lens'], x['head_pos'], x['tail_pos']\n",
    "        mask = seq_len_to_mask(lens)\n",
    "\n",
    "        inputs = self.embedding(word, head_pos, tail_pos)\n",
    "        out, out_pool = self.cnn(inputs, mask=mask)\n",
    "\n",
    "        if self.use_pcnn:\n",
    "            out = out.unsqueeze(-1)  # [B, L, Hs, 1]\n",
    "            pcnn_mask = x['pcnn_mask']\n",
    "            pcnn_mask = self.pcnn_mask_embedding(pcnn_mask).unsqueeze(-2)  # [B, L, 1, 3]\n",
    "            out = out + pcnn_mask  # [B, L, Hs, 3]\n",
    "            out = out.max(dim=1)[0] - 100  # [B, Hs, 3]\n",
    "            out_pool = out.view(out.size(0), -1)  # [B, 3 * Hs]\n",
    "            out_pool = F.leaky_relu(self.fc_pcnn(out_pool))  # [B, Hs]\n",
    "            out_pool = self.dropout(out_pool)\n",
    "\n",
    "        output = self.fc1(out_pool)\n",
    "        output = F.leaky_relu(output)\n",
    "        output = self.dropout(output)\n",
    "        output = self.fc2(output)\n",
    "\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#  p,r,f1 指标测量\n",
    "class PRMetric():\n",
    "    def __init__(self):\n",
    "        \"\"\"\n",
    "        暂时调用 sklearn 的方法\n",
    "        \"\"\"\n",
    "        self.y_true = np.empty(0)\n",
    "        self.y_pred = np.empty(0)\n",
    "\n",
    "    def reset(self):\n",
    "        self.y_true = np.empty(0)\n",
    "        self.y_pred = np.empty(0)\n",
    "\n",
    "    def update(self, y_true:torch.Tensor, y_pred:torch.Tensor):\n",
    "        y_true = y_true.cpu().detach().numpy()\n",
    "        y_pred = y_pred.cpu().detach().numpy()\n",
    "        y_pred = np.argmax(y_pred,axis=-1)\n",
    "\n",
    "        self.y_true = np.append(self.y_true, y_true)\n",
    "        self.y_pred = np.append(self.y_pred, y_pred)\n",
    "\n",
    "    def compute(self):\n",
    "        p, r, f1, _ = precision_recall_fscore_support(self.y_true,self.y_pred,average='macro',warn_for=tuple())\n",
    "        _, _, acc, _ = precision_recall_fscore_support(self.y_true,self.y_pred,average='micro',warn_for=tuple())\n",
    "\n",
    "        return acc,p,r,f1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 训练过程中的迭代\n",
    "def train(epoch, model, dataloader, optimizer, criterion, cfg):\n",
    "    model.train()\n",
    "\n",
    "    metric = PRMetric()\n",
    "    losses = []\n",
    "\n",
    "    for batch_idx, (x, y) in enumerate(dataloader, 1):\n",
    "        optimizer.zero_grad()\n",
    "        y_pred = model(x)\n",
    "        loss = criterion(y_pred, y)\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        metric.update(y_true=y, y_pred=y_pred)\n",
    "        losses.append(loss.item())\n",
    "\n",
    "        data_total = len(dataloader.dataset)\n",
    "        data_cal = data_total if batch_idx == len(dataloader) else batch_idx * len(y)\n",
    "        if (cfg.train_log and batch_idx % cfg.log_interval == 0) or batch_idx == len(dataloader):\n",
    "           # p r f1 皆为 macro，因为micro时三者相同，定义为acc\n",
    "            acc,p,r,f1 = metric.compute()\n",
    "            print(f'Train Epoch {epoch}: [{data_cal}/{data_total} ({100. * data_cal / data_total:.0f}%)]\\t'\n",
    "                        f'Loss: {loss.item():.6f}')\n",
    "            print(f'Train Epoch {epoch}: Acc: {100. * acc:.2f}%\\t'\n",
    "                        f'macro metrics: [p: {p:.4f}, r:{r:.4f}, f1:{f1:.4f}]')\n",
    "\n",
    "    if cfg.show_plot and not cfg.only_comparison_plot:\n",
    "        if cfg.plot_utils == 'matplot':\n",
    "            plt.plot(losses)\n",
    "            plt.title(f'epoch {epoch} train loss')\n",
    "            plt.show()\n",
    "\n",
    "    return losses[-1]\n",
    "\n",
    "\n",
    "# 测试过程中的迭代\n",
    "def validate(epoch, model, dataloader, criterion,verbose=True):\n",
    "    model.eval()\n",
    "\n",
    "    metric = PRMetric()\n",
    "    losses = []\n",
    "\n",
    "    for batch_idx, (x, y) in enumerate(dataloader, 1):\n",
    "        with torch.no_grad():\n",
    "            y_pred = model(x)\n",
    "            loss = criterion(y_pred, y)\n",
    "\n",
    "            metric.update(y_true=y, y_pred=y_pred)\n",
    "            losses.append(loss.item())\n",
    "\n",
    "    loss = sum(losses) / len(losses)\n",
    "    acc,p,r,f1 = metric.compute()\n",
    "    data_total = len(dataloader.dataset)\n",
    "    if verbose:\n",
    "        print(f'Valid Epoch {epoch}: [{data_total}/{data_total}](100%)\\t Loss: {loss:.6f}')\n",
    "        print(f'Valid Epoch {epoch}: Acc: {100. * acc:.2f}%\\tmacro metrics: [p: {p:.4f}, r:{r:.4f}, f1:{f1:.4f}]\\n\\n')\n",
    "\n",
    "    return f1,loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 加载数据集\n",
    "train_dataset = CustomDataset(train_save_fp)\n",
    "valid_dataset = CustomDataset(valid_save_fp)\n",
    "test_dataset = CustomDataset(test_save_fp)\n",
    "\n",
    "train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))\n",
    "valid_dataloader = DataLoader(valid_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))\n",
    "test_dataloader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 因为加载预处理后的数据，才知道vocab_size\n",
    "vocab = load_pkl(vocab_save_fp)\n",
    "vocab_size = vocab.count\n",
    "cfg.vocab_size = vocab_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# main 入口，定义优化函数、loss函数等\n",
    "# 开始epoch迭代\n",
    "# 使用valid 数据集的loss做早停判断，当不再下降时，此时为模型泛化性最好的时刻。\n",
    "model = PCNN(cfg)\n",
    "print(model)\n",
    "\n",
    "optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)\n",
    "scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=cfg.lr_factor, patience=cfg.lr_patience)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "best_f1, best_epoch = -1, 0\n",
    "es_loss, es_f1, es_epoch, es_patience, best_es_epoch, best_es_f1, = 1000, -1, 0, 0, 0, -1\n",
    "train_losses, valid_losses = [], []\n",
    "\n",
    "logger.info('=' * 10 + ' Start training ' + '=' * 10)\n",
    "for epoch in range(1, cfg.epoch + 1):\n",
    "    train_loss = train(epoch, model, train_dataloader, optimizer, criterion, cfg)\n",
    "    valid_f1, valid_loss = validate(epoch, model, valid_dataloader, criterion)\n",
    "    scheduler.step(valid_loss)\n",
    "\n",
    "    train_losses.append(train_loss)\n",
    "    valid_losses.append(valid_loss)\n",
    "    if best_f1 < valid_f1:\n",
    "        best_f1 = valid_f1\n",
    "        best_epoch = epoch\n",
    "    # 使用 valid loss 做 early stopping 的判断标准\n",
    "    if es_loss > valid_loss:\n",
    "        es_loss = valid_loss\n",
    "        es_f1 = valid_f1\n",
    "        best_es_f1 = valid_f1\n",
    "        es_epoch = epoch\n",
    "        best_es_epoch = epoch\n",
    "        es_patience = 0\n",
    "    else:\n",
    "        es_patience += 1\n",
    "        if es_patience >= cfg.early_stopping_patience:\n",
    "            best_es_epoch = es_epoch\n",
    "            best_es_f1 = es_f1\n",
    "\n",
    "if cfg.show_plot:\n",
    "    if cfg.plot_utils == 'matplot':\n",
    "        plt.plot(train_losses, 'x-')\n",
    "        plt.plot(valid_losses, '+-')\n",
    "        plt.legend(['train', 'valid'])\n",
    "        plt.title('train/valid comparison loss')\n",
    "        plt.show()\n",
    "\n",
    "\n",
    "print(f'best(valid loss quota) early stopping epoch: {best_es_epoch}, '\n",
    "            f'this epoch macro f1: {best_es_f1:0.4f}')\n",
    "print(f'total {cfg.epoch} epochs, best(valid macro f1) epoch: {best_epoch}, '\n",
    "            f'this epoch macro f1: {best_f1:.4f}')\n",
    "\n",
    "test_f1, _ = validate(0, model, test_dataloader, criterion,verbose=False)\n",
    "print(f'after {cfg.epoch} epochs, final test data macro f1: {test_f1:.4f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "本demo不包括调参部分，有兴趣的同学可以自行前往 [deepke](http://openkg.cn/tool/deepke) 仓库，下载使用更多的模型 :)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
